LFP example

import altair as alt
from bayes_window import BayesWindow, models, BayesRegression, LMERegression
from bayes_window.generative_models import generate_fake_lfp

try:
    alt.renderers.enable('altair_saver', fmts=['png'])
except Exception:
    pass

Make and visualize simulated oscillation power

30 trials of “theta power” are generated for every animal, drawn at random from a Poisson process.

This is repeated for “stimulation” trials, but with a higher Poisson rate.
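
A minimal sketch of such a generative process (illustration only; the library's actual implementation is generate_fake_lfp, used below, and every name and number here is hypothetical):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_mice, n_trials = 8, 30
baseline = rng.uniform(5, 15, size=n_mice)  # per-mouse baseline Poisson rate
records = []
for mouse in range(n_mice):
    for stim in (0, 1):
        # Stimulation raises the rate; higher-baseline mice gain less (hypothetical response rule)
        rate = baseline[mouse] + stim * (20 - baseline[mouse])
        for i_trial, power in enumerate(rng.poisson(rate, size=n_trials)):
            records.append({'mouse': mouse, 'stim': stim, 'i_trial': i_trial, 'Power': power})
df_sketch = pd.DataFrame(records)
df_sketch['Log power'] = np.log1p(df_sketch['Power'])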

# Draw some fake data:
df, df_monster, index_cols, _ = generate_fake_lfp(mouse_response_slope=15, n_trials=30)

Mice vary in their baseline power.

Higher-baseline mice tend to have a smaller stim response:
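
A quick numeric check of that relationship (a sketch; it assumes the df generated above has the 'mouse', 'stim', and 'Log power' columns used by the plotting calls below):

baseline = df[df['stim'] == 0].groupby('mouse')['Log power'].mean()
response = df[df['stim'] == 1].groupby('mouse')['Log power'].mean() - baseline
print(baseline.corr(response))  # expect a negative correlation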

BayesWindow(df=df, y='Log power', treatment='stim', group='mouse').plot(x='mouse').facet(column='stim')
BayesWindow(df=df, y='Log power', treatment='stim', group='mouse', detail='i_trial').data_box_detail().facet(
    column='mouse')

Fit a Bayesian hierarchical model and plot slopes

In a hierarchical model, parameters are viewed as a sample from a population distribution of parameters. Thus, we view them as neither entirely different nor exactly the same. This is partial pooling:

This hierarchical model allows intercepts to vary across mice, according to a random effect. We add a single fixed slope for the predictor (i.e., all mice share the same slope):

\[y_i = \alpha_{j[i]} + \beta x_{i} + \epsilon_i\]

where:

  • \(j\) is the mouse index

  • \(i\) is the observation index

  • \(y_i\) is the observed power

  • \(x_i\) is 0 (no stimulation) or 1 (stimulation)

  • \(\epsilon_i \sim N(0, \sigma_y^2)\) is the observation error

  • \(\alpha_{j[i]} \sim N(\mu_{\alpha}, \sigma_{\alpha}^2)\) is the random intercept

We set a separate intercept for each mouse, but rather than fitting separate regression models for each mouse, multilevel modeling shares strength among mice, allowing for more reasonable inference in mice with little data.
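
For contrast, the two extremes this prior interpolates between are (standard notation, not specific to this library):

\[\text{complete pooling: } \alpha_{j[i]} \equiv \alpha \qquad \text{no pooling: each } \alpha_j \text{ fit independently, with no shared } \mu_{\alpha}, \sigma_{\alpha}\]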

The wrappers in this library let us fit this model and plot the inference in just three lines of code. Under the hood, they use the following NumPyro code:

from numpyro import sample
from numpyro.distributions import HalfNormal, Normal
import jax.numpy as jnp

# Given: y, treatment, group, n_subjects
# Sample intercepts
a = sample('a', Normal(0, 1))
a_subject = sample('a_subject', Normal(jnp.tile(0, n_subjects), 1))

# Sample variances
sigma_a_subject = sample('sigma_a', HalfNormal(1))
sigma_obs = sample('sigma_obs', HalfNormal(1))

# Sample slope - this is what we are interested in!
b = sample('b_stim', Normal(0, 1))

# Regression equation (non-centered parameterization: per-subject
# intercepts are drawn as standard normals and scaled by sigma_a_subject)
theta = a + a_subject[group] * sigma_a_subject + b * treatment

# Sample power
sample('y', Normal(theta, sigma_obs), obs=y)

Above are the contents of model_hier_stim_one_codition.py: the function passed as the model argument to fit() below.

# Initialize:
window = BayesRegression(df=df, y='Power', treatment='stim', group='mouse')
# Fit:
window.fit(model=models.model_hierarchical, add_group_intercept=True,
           add_group_slope=False, robust_slopes=False,
           do_make_change='subtract', dist_y='gamma')

chart_power_difference = (window.chart + window.chart_posterior_kde).properties(title='Posterior')
chart_power_difference

In this chart:

  • The black line is the 94% posterior highest density interval

  • Shading is posterior density
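
To make the HDI concrete, here is how a 94% HDI can be computed with ArviZ on stand-in samples (an illustration; bayes_window computes this internally):

import arviz as az
import numpy as np

samples = np.random.default_rng(1).normal(0.5, 0.1, size=4000)  # stand-in posterior draws of the slope
print(az.hdi(samples, hdi_prob=0.94))  # the narrowest interval containing 94% of the draws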

# TODO diff_y is missing from data_and_posterior
# chart_power_difference_box
window.data_and_posterior.rename({'Power': 'Power diff'}, axis=1, inplace=True)
# window.plot(x=':O',independent_axes=True).properties(title='Posterior')
window.chart

In this chart:

  • The blue dot is the mean of the posterior

  • The black line is the 94% highest density interval

  • The boxplot is made from the difference between groups in the data (no fitting)

  • The left y-axis is for the posterior, the right one for the data

Compare to non-Bayesian approaches

Off-the-shelf OLS ANOVA

ANOVA does not pick up the effect of stim as significant:

window = LMERegression(df=df, y='Log power', treatment='stim', group='mouse')
window.fit();
Using formula Log_power ~  C(stim, Treatment) + (1 | mouse)
                         Coef. Std.Err.       z  P>|z| [0.025 0.975]
Intercept                1.954    0.039  50.684  0.000  1.879  2.030
C(stim, Treatment)[T.1]  0.153    0.033   4.600  0.000  0.088  0.218
1 | mouse                0.051    0.008   6.680  0.000  0.036  0.066
Group Var                0.000    0.004                             
ConvergenceWarning: Maximum Likelihood optimization failed to converge.
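
For reference, a roughly equivalent direct statsmodels call might look like this (a sketch, with an assumed Log_power column; bayes_window wraps something similar):

import statsmodels.formula.api as smf

df_sm = df.rename(columns={'Log power': 'Log_power'})
lme = smf.mixedlm('Log_power ~ C(stim)', df_sm, groups=df_sm['mouse']).fit()
print(lme.summary())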

Including mouse as a predictor helps:

window.fit(formula='Log_power ~ stim + mouse + mouse*stim');
Using formula Log_power ~ stim + mouse + mouse*stim
             Coef. Std.Err.       z  P>|z|  [0.025 0.975]
Intercept    1.944    0.043  44.734  0.000   1.859  2.029
stim         0.225    0.060   3.722  0.000   0.107  0.343
mouse        0.061    0.010   5.864  0.000   0.041  0.081
mouse:stim  -0.021    0.014  -1.427  0.154  -0.049  0.008
Group Var    0.000    0.007                              

OLS ANOVA with heteroscedasticity correction

window.fit(formula='Log_power ~ stim + mouse + mouse*stim', robust="hc3");
Using formula Log_power ~ stim + mouse + mouse*stim
             Coef. Std.Err.       z  P>|z|  [0.025 0.975]
Intercept    1.944    0.043  44.734  0.000   1.859  2.029
stim         0.225    0.060   3.722  0.000   0.107  0.343
mouse        0.061    0.010   5.864  0.000   0.041  0.081
mouse:stim  -0.021    0.014  -1.427  0.154  -0.049  0.008
Group Var    0.000    0.007                              

A linear mixed-effects model, with random intercepts for each mouse, shows the effect of stim (the slope) as significant:

# Initialize:
window = LMERegression(df=df, y='Log power', treatment='stim', group='mouse')
window.fit(add_data=False);
Using formula Log_power ~  C(stim, Treatment) + (1 | mouse)
                         Coef. Std.Err.       z  P>|z| [0.025 0.975]
Intercept                1.954    0.039  50.684  0.000  1.879  2.030
C(stim, Treatment)[T.1]  0.153    0.033   4.600  0.000  0.088  0.218
1 | mouse                0.051    0.008   6.680  0.000  0.036  0.066
Group Var                0.000    0.004                             
chart_power_difference_lme = window.plot().properties(title='LME')
chart_power_difference_lme

Compare LME and Bayesian slopes side by side

chart_power_difference | chart_power_difference_lme

Inspect Bayesian result further

Let’s take a look at the intercepts and compare them to levels of power in the original data:

# Initialize:
window = BayesRegression(df=df, y='Power', treatment='stim', group='mouse', detail='i_trial')
# Fit:
window.fit(model=models.model_hierarchical, add_group_intercept=True,
           add_group_slope=False, robust_slopes=False,
           do_make_change='subtract', dist_y='gamma');

chart_detail_and_intercepts = window.plot_intercepts(x='mouse')
window.chart_posterior_intercept
chart_detail_and_intercepts

Our plotting backend’s flexibility lets us easily concatenate multiple charts in the same figure with the | operator:

window.chart_posterior_intercept | chart_power_difference | chart_power_difference_lme

Check for false positives with a null model

False positives sometimes appear with non-transformed data and a “normal” model:

# Initialize:
df_null, df_monster_null, _, _ = generate_fake_lfp(mouse_response_slope=0, n_trials=30)
window = BayesRegression(df=df_null, y='Power', treatment='stim', group='mouse')
# Fit:
window.fit(model=models.model_hierarchical, add_group_intercept=True,
           add_group_slope=False, robust_slopes=False,
           do_make_change='subtract', dist_y='normal')

# Plot:
chart_power_difference = window.plot(independent_axes=False,
                                     ).properties(title='Posterior')

chart_power_difference

This does not happen if we estimate group slopes.
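
For instance, the same null fit with per-mouse slopes might look like this (a sketch, not executed here; add_group_slope is the same flag used in the fits above):

window = BayesRegression(df=df_null, y='Power', treatment='stim', group='mouse')
window.fit(model=models.model_hierarchical, add_group_intercept=True,
           add_group_slope=True, robust_slopes=False,
           do_make_change='subtract', dist_y='normal')
window.plot(independent_axes=False).properties(title='Posterior with group slopes')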

The gamma GLM is more robust in the case of no effect:

# Initialize:
window = BayesRegression(df=df_null, y='Power', treatment='stim', group='mouse')
# Fit:
window.fit(model=models.model_hierarchical, add_group_intercept=True,
           add_group_slope=False, robust_slopes=False,
           do_make_change='subtract', dist_y='gamma')
# Plot:
window.plot(independent_axes=False,
            ).properties(title='Posterior')

Include all samples in each trial

Each of the 30 trials we drew for each mouse is a manifestation of the same underlying process that generates power for that mouse. Rather than averaging over trials, let’s include every sample from every trial:

# NBVAL_SKIP
# Initialize:
window = BayesRegression(df=df_monster, y='Power', treatment='stim', group='mouse')
# Fit:
window.fit(model=models.model_hierarchical, add_group_intercept=True,
           num_warmup=500, n_draws=160, progress_bar=True,
           add_group_slope=False, robust_slopes=False,
           do_make_change='subtract', dist_y='gamma');
# NBVAL_SKIP
alt.data_transformers.disable_max_rows()
chart_power_difference_monster = window.plot(independent_axes=False).properties(title='Posterior')
chart_power_difference_monster

Much tighter credible intervals here!

Same with a linear mixed model:

# NBVAL_SKIP
window = LMERegression(df=df_monster,
                       y='Log power', treatment='stim', group='mouse')
window.fit()

chart_power_difference_monster_lme = window.plot().properties(title='LME')
chart_power_difference_monster_lme
# NBVAL_SKIP
chart_power_difference_monster | chart_power_difference_monster_lme